import numpy as np
from scipy import stats
import scipy.optimize as opt
import time_dependent as td
from menu_cost import transition_matrix, parabolic_max


def backward_iteration(V_p_stacked, Q, x_grid, r, mu, tau, distribution, sig, la, mu_k, sig_k, gap, discount, elast):
    """Take one backward step of the value recursion on the stacked 2-D grid.

    Entry ``k`` of every stacked vector refers to the grid pair
    ``(x_grid[k // nx], x_grid[k % nx])`` (row-major reshape of an
    ``(nx, nx)`` array), so the "diagonal" states with both components equal
    sit at indices ``(nx + 1) * i``.

    Parameters
    ----------
    V_p_stacked : next-period value, one entry per stacked state.
    Q : matrix applied to ``V_p_stacked`` to form the continuation value;
        its diagonal-state rows are also interpolated to build ``Q_adj``.
    x_grid : one-dimensional grid (must support NumPy broadcasting).
    r : the continuation value is discounted by ``1 / (1 + r)``.
    mu : subtracted from the price output ``output['p']``.
    tau, distribution, sig : not used here; accepted so that the full
        parameter dict can be passed with ``**parameters`` (see ``pol_ss``).
    la, mu_k, sig_k : the adjustment probability is
        ``la + (1 - la) * Phi((log(gain) - mu_k) / sig_k)``.
    gap : horizontal shift of the per-component payoff.
    discount : the payoff is scaled by ``1 + discount``.
    elast : ``None`` selects the quadratic payoff, otherwise the exponential
        form parameterised by ``elast`` is used.

    Returns
    -------
    V_stacked, q_stacked, x_max, Q_adj, ind_adj, output
        Updated value, adjustment probability per state, interpolated
        maximiser, interpolated row of ``Q`` at ``x_max``, the two grid
        values bracketing ``x_max`` and ``{'p': ...}``.
    """
    nx = len(x_grid)
    dev = x_grid - gap

    # payoff of a single component, evaluated on the grid
    if elast is not None:
        payoff = (1+discount)*(elast*np.exp(-(elast-1)*dev) - (elast-1)*np.exp(-elast*dev))/(0.5*elast*(elast-1))
    else:
        payoff = -(1+discount)*dev**2

    # total payoff is the sum over the two components, stacked row-major
    flow = (payoff[:, np.newaxis] + payoff[np.newaxis, :]).reshape(nx**2)

    # value of not adjusting
    V_wait = flow + (1 / (1 + r)) * (Q @ V_p_stacked)

    # the search for the best reset point is restricted to the diagonal
    # (both components equal); a stride of nx + 1 picks exactly those states
    i_diag = np.argmax(V_wait[::nx + 1])
    k_diag = (nx+1) * i_diag

    # parabolic interpolation through the neighbouring diagonal states so the
    # maximum can fall between grid points, using the symmetry of the problem
    # NOTE(review): parabolic_max is defined elsewhere; presumably it returns
    # the location and value of the fitted parabola's vertex - confirm.
    x_max, V_adj = parabolic_max(
        x_grid[i_diag - 1], x_grid[i_diag], x_grid[i_diag + 1],
        V_wait[k_diag - nx - 1], V_wait[k_diag], V_wait[k_diag + nx + 1],
    )

    # linear interpolation between the two diagonal rows of Q bracketing x_max
    hi = np.flatnonzero(x_grid > x_max)[0]
    lo = hi - 1
    w_lo = (x_grid[hi] - x_max) / (x_grid[hi] - x_grid[lo])
    Q_adj = w_lo * Q[(nx+1) * lo, :] + (1 - w_lo) * Q[(nx+1) * hi, :]

    # adjustment probability from the log gain, floored at log(1e-6)
    log_gain = np.log(np.maximum(V_adj - V_wait, 1e-6))
    q_stacked = la + (1 - la) * stats.norm.cdf((log_gain - mu_k) / sig_k)

    # expected menu cost payment
    EK = (1 - la) * (np.exp(mu_k + 0.5 * sig_k ** 2) * stats.norm.cdf((log_gain - mu_k - sig_k ** 2) / sig_k))

    V_stacked = q_stacked * V_adj + (1 - q_stacked) * V_wait - EK

    # average of the two components for states that do not adjust
    p_lagged = (0.5 * (x_grid[:, np.newaxis] + x_grid[np.newaxis, :])).reshape(nx**2)
    output = {'p': (1 - q_stacked) * p_lagged + q_stacked * x_max - mu}

    return V_stacked, q_stacked, x_max, Q_adj, [x_grid[lo], x_grid[hi]], output


def pol_ss(parameters, x_grid, V_seed=None, tol=1e-6, maxit=10000):
    """Solve for the steady-state value function and policy by iteration.

    ``backward_iteration`` is applied repeatedly until the sup-norm change of
    the value function is below ``tol``.

    Parameters
    ----------
    parameters : dict passed to ``backward_iteration`` via ``**parameters``;
        the keys 'mu', 'sig', 'tau' and 'distribution' are also used to build
        the transition matrix.
    x_grid : one-dimensional grid.
    V_seed : optional starting value (defaults to zeros of length ``nx**2``).
    tol : convergence tolerance on ``max |V_new - V_old|``.
    maxit : maximum number of backward iterations.

    Returns
    -------
    V_stacked, Pi, q_stacked, x_max, ind_adj, Q, output_ss
        ``Q`` is the Kronecker product of the one-dimensional transition
        matrix with itself and ``Pi`` mixes ``Q_adj`` and ``Q`` row by row
        with the adjustment probabilities.

    Raises
    ------
    ValueError
        If the iteration has not converged after ``maxit`` steps.
    """
    nx = len(x_grid)
    V_prev = np.zeros(nx**2) if V_seed is None else V_seed

    mu = parameters['mu']
    sig = parameters['sig']
    tau = parameters['tau']
    distribution = parameters['distribution']

    # both components follow the same one-dimensional transition
    Q_single = transition_matrix(-mu, sig, distribution, tau, x_grid)
    Q = np.kron(Q_single, Q_single)

    converged = False
    for _ in range(maxit):
        V_stacked, q_stacked, x_max, Q_adj, ind_adj, output_ss = backward_iteration(
            V_prev, Q, x_grid, **parameters)

        if np.max(np.abs(V_stacked - V_prev)) < tol:
            converged = True
            break

        V_prev = V_stacked

    if not converged:
        raise ValueError(f'No convergence after {maxit} backward iterations!')

    # adjusters move according to Q_adj, everybody else according to Q
    adjust = q_stacked[:, np.newaxis]
    Pi = adjust * Q_adj[np.newaxis, :] + (1 - adjust) * Q

    return V_stacked, Pi, q_stacked, x_max, ind_adj, Q, output_ss


def dist_ss(Pi, x_grid, g_seed=None, tol=1e-12, maxit=100000):
    """Find the stationary distribution of ``Pi`` by forward iteration.

    Starting from ``g_seed`` (or the uniform distribution over ``nx**2``
    states), the update ``g <- Pi.T @ g`` is repeated until the largest
    absolute change is below ``tol``.

    Parameters
    ----------
    Pi : transition matrix whose rows index the current state.
    x_grid : one-dimensional grid; only its length is used, and only when
        ``g_seed`` is not given.
    g_seed : optional starting distribution (not modified).
    tol : convergence tolerance on ``max |g_new - g_old|``.
    maxit : maximum number of iterations.

    Returns
    -------
    The last iterate. If ``maxit`` is 0 the starting distribution is returned.

    Warns
    -----
    RuntimeWarning
        If the tolerance is not reached within ``maxit`` iterations; the last
        iterate is still returned.
    """
    import warnings

    if g_seed is None:
        nx = len(x_grid)
        g = np.ones(nx**2)
        g = g / np.sum(g)
    else:
        g = g_seed

    for _ in range(maxit):
        g_next = Pi.T @ g
        diff = np.max(np.abs(g_next - g))
        g = g_next

        # The change is computed on every pass anyway, so test it on every
        # pass; the previous `it % 10 and ...` skipped every tenth iteration.
        if diff < tol:
            break
    else:
        warnings.warn(
            f'dist_ss did not reach tol={tol} after {maxit} iterations',
            RuntimeWarning,
            stacklevel=2,
        )

    return g


def steady_state(parameters, x_grid, V_seed=None, g_seed=None, tol_pol=1e-6):
    """Compute the steady state: policy, stationary distribution and moments.

    Chains ``pol_ss`` (value function and transition matrix), ``dist_ss``
    (stationary distribution of that transition matrix) and ``statistics``
    (summary moments).

    Parameters
    ----------
    parameters : parameter dict forwarded to ``pol_ss``.
    x_grid : one-dimensional grid.
    V_seed, g_seed : optional starting points for the two iterations.
    tol_pol : tolerance forwarded to ``pol_ss``.

    Returns
    -------
    V_stacked, Pi, q_stacked, x_max, moments, g_stacked, Q, output_ss
        ``moments`` is the dict returned by ``statistics``.
    """
    policy = pol_ss(parameters, x_grid, V_seed=V_seed, tol=tol_pol)
    V_stacked, Pi, q_stacked, x_max, ind_adj, Q, output_ss = policy

    g_stacked = dist_ss(Pi, x_grid, g_seed=g_seed)
    moments = statistics(q_stacked, x_max, ind_adj, g_stacked, x_grid)

    return V_stacked, Pi, q_stacked, x_max, moments, g_stacked, Q, output_ss


def backward_iteration_transition(parameters, V_stacked_ss, Q, g_stacked_ss, output_ss, x_grid, T, h, shock, output_list):
    """Iterate backwards from a one-off shock and collect anticipation effects.

    The first backward step (``t == 0``) uses ``parameters[shock] + h``; every
    later step uses the unperturbed parameters. The caller's ``parameters``
    dict is never modified: the shocked values live in a copy, so the
    perturbation cannot leak out when ``T == 1`` or when a step raises.

    Parameters
    ----------
    parameters : steady-state parameter dict (left untouched).
    V_stacked_ss : steady-state value function used as terminal value.
    Q : transition matrix passed through to ``backward_iteration``.
    g_stacked_ss : steady-state distribution used to aggregate outputs.
    output_ss : dict of steady-state outputs, keyed like ``output_list``.
    x_grid : one-dimensional grid.
    T : number of backward steps.
    h : size of the perturbation added to ``parameters[shock]``.
    shock : key of the perturbed parameter.
    output_list : names of the outputs to aggregate.

    Returns
    -------
    anticipation_effects, q_transition, Q_adj_transition, output_transition
        ``anticipation_effects[out]`` is a ``(T, T)`` array whose entry
        ``[row, col]`` with ``col >= row`` holds the aggregated response at
        backward step ``col - row``; entries below the diagonal are zero.
        The three lists are returned in reversed order (last backward step
        first).
    """
    q_transition = []
    Q_adj_transition = []
    output_transition = []

    V_p_stacked = V_stacked_ss

    for t in range(T):
        if t == 0:
            # perturb a copy so the caller's dict keeps its steady-state value
            step_parameters = dict(parameters)
            step_parameters[shock] = parameters[shock] + h
        else:
            step_parameters = parameters

        V_stacked, q_stacked, _, Q_adj, _, output = backward_iteration(
            V_p_stacked, Q, x_grid, **step_parameters)

        q_transition.append(q_stacked)
        Q_adj_transition.append(Q_adj)
        output_transition.append(output)

        V_p_stacked = V_stacked

    anticipation_effects = {}
    for out in output_list:
        # the response depends only on the lag col - row: compute each lag once
        lag_effect = np.zeros(T)
        for lag in range(T):
            lag_effect[lag] = ((output_transition[lag][out] - output_ss[out]) / h) @ g_stacked_ss

        effects = np.zeros((T, T))
        for row in range(T):
            effects[row, row:] = lag_effect[:T - row]
        anticipation_effects[out] = effects

    q_transition.reverse()
    Q_adj_transition.reverse()
    output_transition.reverse()

    return anticipation_effects, q_transition, Q_adj_transition, output_transition


def forward_iteration_transition(Pi_ss, q_transition, Q_adj_transition, Q, g_stacked_ss, output_ss, x_grid, T, h, output_list):
    """Compute the history effects of perturbed transition matrices.

    For every output a matrix ``H`` of shape ``(nx**2, T)`` is built whose
    column ``s`` is ``Pi_ss`` applied ``s`` times to the steady-state output.
    The perturbed transition matrices are then applied one at a time, using
    the entries of ``q_transition`` / ``Q_adj_transition`` from the last
    index down to the first; after each application the columns are shifted
    right by one and column 0 is reset to the steady-state output. The
    deviation from the steady-state ``H``, aggregated with ``g_stacked_ss``
    and divided by ``h``, fills one column of the result.

    Parameters
    ----------
    Pi_ss : steady-state transition matrix.
    q_transition, Q_adj_transition : per-period adjustment probabilities and
        adjuster transition rows, each of length ``T``.
    Q : transition matrix for states that do not adjust.
    g_stacked_ss : steady-state distribution.
    output_ss : dict of steady-state outputs.
    x_grid : one-dimensional grid (only its length is used).
    T : horizon.
    h : perturbation size used for the finite difference.
    output_list : names of the outputs to process.

    Returns
    -------
    dict mapping each output name to a ``(T, T)`` array.
    """
    n_states = len(x_grid) ** 2

    history_effects = {out: np.zeros((T, T)) for out in output_list}

    # steady-state reference: column s = Pi_ss^s @ output_ss
    H_ss = {}
    for out in output_list:
        reference = np.zeros((n_states, T))
        for s in range(T):
            if s == 0:
                reference[:, s] = output_ss[out]
            else:
                reference[:, s] = Pi_ss @ reference[:, s - 1]
        H_ss[out] = reference

    # working copies are rebound (never written in place), so H_ss stays intact
    H = dict(H_ss)

    for it in range(T):
        s = T - it - 1
        adjust = q_transition[s][:, np.newaxis]
        Pi = adjust * Q_adj_transition[s][np.newaxis, :] + (1 - adjust) * Q

        for out in output_list:
            propagated = Pi @ H[out]

            # shift every column one step to the right, restart column 0
            shifted = np.empty_like(propagated)
            shifted[:, 1:] = propagated[:, :-1]
            shifted[:, 0] = output_ss[out]
            H[out] = shifted

            deviation = (shifted.T - H_ss[out].T) @ g_stacked_ss
            history_effects[out][:, it] = deviation / h

    return history_effects


def compute_jacobian(parameters, x_grid, T, h, shock_list, output_list):
    """Build the ``(T, T)`` response matrices of outputs to parameter shocks.

    The steady state is solved once; for every shock the anticipation effects
    (``backward_iteration_transition``) and the history effects
    (``forward_iteration_transition``) are added.

    Parameters
    ----------
    parameters : steady-state parameter dict.
    x_grid : one-dimensional grid.
    T : horizon.
    h : perturbation size.
    shock_list : parameter names to perturb.
    output_list : output names to track.

    Returns
    -------
    J, ss
        ``J[out][shock]`` is a ``(T, T)`` array; ``ss`` collects the
        steady-state objects under the keys 'V', 'Pi', 'q', 'x_max', 'stats',
        'g', 'x_grid' and 'output'.
    """
    (V_stacked_ss, Pi_ss, q_stacked_ss, x_max_ss,
     moments, g_stacked_ss, Q, output_ss) = steady_state(parameters, x_grid)

    ss = {
        'V': V_stacked_ss,
        'Pi': Pi_ss,
        'q': q_stacked_ss,
        'x_max': x_max_ss,
        'stats': moments,
        'g': g_stacked_ss,
        'x_grid': x_grid,
        'output': output_ss,
    }

    J = {out: {} for out in output_list}

    for shock in shock_list:
        anticipation, q_path, Q_adj_path, _ = backward_iteration_transition(
            parameters, V_stacked_ss, Q, g_stacked_ss, output_ss, x_grid, T, h, shock, output_list)
        history = forward_iteration_transition(
            Pi_ss, q_path, Q_adj_path, Q, g_stacked_ss, output_ss, x_grid, T, h, output_list)

        for out in output_list:
            J[out][shock] = (anticipation[out] + history[out])

    return J, ss


def statistics(q_stacked, x_max, ind_adj, g_stacked, x_grid):
    """Compute adjustment frequency and size moments from a distribution.

    Each stacked state contributes two observations, one per component. An
    observation is the distance ``x_max - x`` of that component from the
    reset point, weighted by the state's adjustment probability times its
    mass. Observations whose component already sits at the reset point are
    given zero adjustment probability: the grid value 0 when ``x_max`` is
    close to 0, otherwise the two grid values in ``ind_adj``.

    Parameters
    ----------
    q_stacked : adjustment probability per stacked state.
    x_max : reset point.
    ind_adj : pair of grid values treated as "already at the reset point".
    g_stacked : mass per stacked state.
    x_grid : one-dimensional grid (NumPy array).

    Returns
    -------
    dict with keys 'freq' (weighted adjustment probability), 'mean_abs_dp'
    (weighted mean of ``|x_max - x|``) and 'med_abs_dp' (weighted median,
    taken as the average of the observation at which the cumulative weight
    first exceeds 0.5 and the one just before it in sorted order).
    """
    nx = len(x_grid)

    # first and second component of every stacked state (row-major order)
    x1 = x_grid[:, np.newaxis] + np.zeros((nx, nx))
    x2 = x1.T
    x1 = np.reshape(x1, (nx ** 2,))
    x2 = np.reshape(x2, (nx ** 2,))

    dp1 = x_max - x1
    dp2 = x_max - x2

    q1 = q_stacked.copy()
    q2 = q_stacked.copy()

    if np.isclose(x_max, 0):
        q1[np.isclose(x1, 0)] = 0
        q2[np.isclose(x2, 0)] = 0
    else:
        q1[np.logical_or(np.isclose(x1, ind_adj[0]), np.isclose(x1, ind_adj[1]))] = 0
        q2[np.logical_or(np.isclose(x2, ind_adj[0]), np.isclose(x2, ind_adj[1]))] = 0

    dp = np.concatenate([dp1, dp2])
    q = np.concatenate([q1, q2])
    g = np.concatenate([g_stacked, g_stacked])
    g = g / np.sum(g)

    freq = np.sum(q * g)
    weight = q * g / freq

    abs_dp = np.abs(dp)
    mean_abs_dp = np.sum(abs_dp * weight)

    # sort once to obtain the weighted median
    order = np.argsort(abs_dp)
    abs_sorted = abs_dp[order]
    cum_weight = np.cumsum(weight[order])

    i_med = np.where(cum_weight > 0.5)[0][0]
    # When the very first observation already carries more than half of the
    # weight there is no predecessor; without the clamp, index -1 would wrap
    # around to the largest observation.
    i_prev = max(i_med - 1, 0)
    med_abs_dp = 0.5 * abs_sorted[i_med] + 0.5 * abs_sorted[i_prev]

    return {'freq': freq,
            'mean_abs_dp': mean_abs_dp,
            'med_abs_dp': med_abs_dp
            }


def permanent_shock(x_grid, Pi, g_stacked, output_p, T, h):
    """Trace the response of ``output_p`` to a one-off shift of the distribution.

    The steady-state mass is moved towards lower grid indices along both
    dimensions by linear interpolation (weight ``1 - |h| / dx`` stays in
    place, the remainder comes from the next higher grid point; mass that
    would leave the grid at index 0 is kept there). The shifted distribution
    is then propagated with ``Pi`` for ``T`` periods.

    Parameters
    ----------
    x_grid : one-dimensional grid; the spacing is taken from its first two
        points.
    Pi : transition matrix (rows index the current state).
    g_stacked : steady-state distribution over the ``nx**2`` stacked states.
    output_p : output per stacked state.
    T : number of periods.
    h : shock size; its sign is ignored.

    Returns
    -------
    Array of length ``T`` with ``(p_t + |h|) / |h|``, where ``p_t`` is the
    aggregate deviation of ``output_p`` from its steady-state value.
    """
    h = -h if h < 0 else h

    nx = len(x_grid)
    step = x_grid[1] - x_grid[0]
    w = 1 - h / step

    def shift_along_columns(mass):
        # interpolate between each column and its right-hand neighbour
        moved = np.zeros_like(mass)
        moved[:, 0:-1] = w * mass[:, 0:-1] + (1 - w) * mass[:, 1:]
        moved[:, -1] = w * mass[:, -1]
        # mass that would fall off the lower edge stays in the first column
        moved[:, 0] += (1 - w) * mass[:, 0]
        return moved

    g_grid = np.reshape(g_stacked, (nx, nx))
    # second component first, then the same shift along the first component
    g_partial = shift_along_columns(g_grid)
    g_shifted = shift_along_columns(g_partial.T).T

    baseline = g_stacked.copy()
    current = np.reshape(g_shifted, (nx ** 2,))

    response = np.zeros(T)
    for t in range(T):
        response[t] = np.sum(output_p * (current - baseline))
        current = Pi.T @ current

    return (response + h) / h


def extend_matrix(X, T_extend):
    """Embed a square matrix into a larger one by repeating its middle column.

    ``X`` is copied into the top-left corner of a ``(T_extend, T_extend)``
    zero matrix. Then, for every column from the middle column
    ``t_middle = len(X) // 2`` onwards, the middle column of ``X`` is written
    starting at row ``column - t_middle`` (truncated at the bottom edge).
    This also overwrites the lower part of the original columns to the right
    of the middle column.

    Parameters
    ----------
    X : square NumPy array of shape ``(T, T)``.
    T_extend : size of the result; must be at least ``T``.

    Returns
    -------
    The ``(T_extend, T_extend)`` array.

    Raises
    ------
    ValueError
        If ``T_extend`` is smaller than ``len(X)``.
    """
    T = len(X)
    if T_extend < T:
        raise ValueError(f'T_extend={T_extend} is smaller than the matrix size {T}')

    # np.int was removed in NumPy 1.24; floor division gives the same index
    t_middle = T // 2
    middle_column = X[:, t_middle]

    X_large = np.zeros((T_extend, T_extend))
    X_large[:T, :T] = X

    for shift in range(T_extend - t_middle):
        column = t_middle + shift
        row_start = shift
        row_stop = min(shift + T, T_extend)
        X_large[row_start:row_stop, column] = middle_column[0:row_stop - row_start]

    return X_large


def compute_PC(J, x_grid, parameters=None, ss=None, h=None, T_extend=1000, permanent_shock_to_normalize=False, cut_back=False, return_nominal=False):
    """Convert a response matrix ``J`` into its "real" counterpart ``J_pc``.

    The matrix is optionally rescaled with the response to a permanent
    shock, extended to ``T_extend`` periods with ``extend_matrix``, and the
    system ``(I - J) @ J_real = J`` is solved after preconditioning. The
    result is first-differenced along the rows before it is returned.

    Parameters
    ----------
    J : square ``(T, T)`` NumPy array.
    x_grid : one-dimensional grid (used only for the permanent-shock
        normalisation).
    parameters : dict containing the key 'r'; required despite the ``None``
        default.
    ss : steady-state dict with keys 'Pi', 'g' and 'output' (with entry 'p');
        required when ``permanent_shock_to_normalize`` is true.
    h : shock size; required when ``permanent_shock_to_normalize`` is true.
    T_extend : size used for the computation; replaced by ``T + 100`` when it
        does not exceed ``T``.
    permanent_shock_to_normalize : rescale ``J`` before extending it.
    cut_back : return only the leading ``(T, T)`` block of ``J_pc``.
    return_nominal : additionally return the extended ``J``.

    Returns
    -------
    ``J_pc``, or ``(J_pc, J)`` when ``return_nominal`` is true. The returned
    ``J`` is the extended matrix and is not cut back.

    Raises
    ------
    TypeError
        If a required argument is left at ``None``.
    """
    if parameters is None:
        raise TypeError("compute_PC requires `parameters` (a dict with key 'r')")

    T = len(J)
    if T >= T_extend:
        T_extend = T + 100

    if permanent_shock_to_normalize:
        if ss is None or h is None:
            raise TypeError('permanent_shock_to_normalize=True requires `ss` and `h`')
        # normalize by permanent shock to make sure far out columns sum to 1
        # NOTE(review): the code rescales each row (sums over axis=1) although
        # the comment above speaks of columns - confirm which is intended.
        shock = permanent_shock(x_grid, ss['Pi'], ss['g'], ss['output']['p'], T, h)
        J = J * shock[:, np.newaxis] / np.sum(J, axis=1)[:, np.newaxis]

    # extend matrix to increase precision
    J = extend_matrix(J, T_extend)

    # we want to solve (I-J) @ J_real = J, but I-J is ill-conditioned
    # use the Calvo approximation (without imposing it!) to find a matrix P such that P @ (I-J) is well-conditioned
    # then solve (P @ (I-J)) @ J_real = P @ J or A @ J_real = B
    beta = 1 / (1 + parameters['r'])
    J_calvo = np.cumsum(td.calvo_PC(1, beta, T_extend), axis=0)
    P = np.linalg.solve(J.T, J_calvo.T).T  # P @ J == J_calvo
    A = P @ (np.eye(T_extend) - J)
    B = J_calvo
    J_pc = np.linalg.solve(A, B)

    # first difference along the rows; the right-hand side is evaluated into
    # a temporary before the assignment, so the overlap is harmless
    J_pc[1:, :] = J_pc[1:, :] - J_pc[:-1, :]

    if cut_back:
        # slice rows AND columns; J_pc[0:T][0:T] only sliced the rows twice
        J_pc = J_pc[0:T, 0:T]

    if return_nominal:
        return J_pc, J

    return J_pc


def calibrate_model(targs, parameters, initial_guess, nx=101, xmax=0.25, optimize=True):
    """Choose parameters so that steady-state moments match their targets.

    Parameters
    ----------
    targs : dict mapping moment names (keys of the dict returned by
        ``statistics``) to target values.
    parameters : dict of model parameters; it is copied, not modified.
    initial_guess : dict mapping the names of the parameters to calibrate to
        their starting values.
    nx, xmax : the grid is ``np.linspace(-xmax, xmax, nx)``.
    optimize : when false, the initial guess is used as is.

    Returns
    -------
    par_out, ss
        The parameter dict with calibrated values filled in and a dict with
        the steady-state objects ('V', 'Pi', 'q', 'x_max', 'stats', 'g', 'Q',
        'x_grid', 'output'). A message is printed when a moment misses its
        target by 1e-4 or more.
    """
    x_grid = np.linspace(-xmax, xmax, nx)

    names = initial_guess.keys()
    x0 = np.array(list(initial_guess.values()))

    def with_values(values):
        # parameter dict with the calibrated entries replaced by `values`
        merged = parameters.copy()
        merged.update(zip(names, values))
        return merged

    def residuals(values):
        # moment minus target, in the order of `targs`
        moments = steady_state(with_values(values), x_grid)[4]
        return np.array([moments[k] - targs[k] for k in targs])

    x = opt.least_squares(residuals, x0=x0).x if optimize else x0

    par_out = with_values(x)
    V, Pi, q, x_max, moments, g, Q, output_ss = steady_state(par_out, x_grid)

    ss = {
        'V': V,
        'Pi': Pi,
        'q': q,
        'x_max': x_max,
        'stats': moments,
        'g': g,
        'Q': Q,
        'x_grid': x_grid,
        'output': output_ss,
    }

    on_target = [np.abs(moments[k] - targs[k]) < 1e-4 for k in targs]
    if not np.all(on_target):
        print('Calibration unsuccessful.')

    return par_out, ss
